"""
This example trains a SparseEncoder for the SNLI + MultiNLI (AllNLI) dataset.
The training script fine-tunes a SparseEncoder using the Splade loss function for retrieval.
It loads a subset of the AllNLI dataset, splits it into training and evaluation subsets,
and he model is evaluated on the STS benchmark dataset. After training, the model is evaluated and
saved locally, with an optional step to push the trained model to the Hugging Face Hub.

Usage:
python train_splade_nli.py
"""

import logging
import traceback

from datasets import load_dataset

from sentence_transformers import (
    SparseEncoder,
    SparseEncoderModelCardData,
    SparseEncoderTrainer,
    SparseEncoderTrainingArguments,
)
from sentence_transformers.similarity_functions import SimilarityFunction
from sentence_transformers.sparse_encoder import evaluation, losses
from sentence_transformers.training_args import BatchSamplers

# Set the log level to INFO to get more information
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)


def main():
    model_name = "naver/splade-cocondenser-ensembledistil"
    short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]

    train_batch_size = 16
    num_epochs = 1

    # 1a. Load a model to finetune with 1b. (Optional) model card data
    model = SparseEncoder(
        model_name,
        model_card_data=SparseEncoderModelCardData(
            language="en",
            license="apache-2.0",
            model_name=f"{short_model_name} trained on Natural Language Inference (NLI)",
        ),
        similarity_fn_name="dot",  # or cosine but dot tends to be stronger
    )
    model.max_seq_length = 256  # Set the max sequence length to 256 for the training
    logging.info("Model max length: %s", model.max_seq_length)

    # 2. Load the AllNLI dataset: https://huggingface.co/datasets/sentence-transformers/all-nli
    # We'll start with 10k training samples, but you can increase this to get a stronger model
    logging.info("Read AllNLI train dataset")
    train_dataset = load_dataset("sentence-transformers/all-nli", "pair-score", split="train").select(range(10000))
    eval_dataset = load_dataset("sentence-transformers/all-nli", "pair-score", split="dev").select(range(1000))
    logging.info(train_dataset)
    logging.info(eval_dataset)

    # 3. Define our training loss.
    document_regularizer_weight = 3e-3

    loss = losses.SpladeLoss(
        model=model,
        loss=losses.SparseMultipleNegativesRankingLoss(
            model=model,
            scale=1,  # need to be adapt if used cosine similarity
            similarity_fct=model.similarity,  # Use the same similarity function as the model
        ),
        document_regularizer_weight=document_regularizer_weight,  # Weight for document loss
        use_document_regularizer_only=True,
    )

    # 4. Define an evaluator for use during training. This is useful to keep track of alongside the evaluation loss.
    stsb_eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")
    dev_evaluator = evaluation.SparseEmbeddingSimilarityEvaluator(
        sentences1=stsb_eval_dataset["sentence1"],
        sentences2=stsb_eval_dataset["sentence2"],
        scores=stsb_eval_dataset["score"],
        main_similarity=SimilarityFunction.COSINE,
        name="sts-dev",
    )
    dev_evaluator(model)

    # 5. Define the training arguments
    run_name = f"{short_model_name}-nli"
    training_args = SparseEncoderTrainingArguments(
        # Required parameter:
        output_dir=f"models/{run_name}",
        # Optional training parameters:
        num_train_epochs=num_epochs,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=train_batch_size,
        learning_rate=4e-6,
        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=True,  # Set to True if you have a GPU that supports BF16
        batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
        load_best_model_at_end=True,
        metric_for_best_model="eval_sts-dev_spearman_cosine",
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=120,
        save_strategy="steps",
        save_steps=120,
        save_total_limit=2,
        logging_steps=20,
        run_name=run_name,  # Will be used in W&B if `wandb` is installed
        seed=42,
    )

    # 6. Create the trainer & start training
    trainer = SparseEncoderTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
        evaluator=dev_evaluator,
    )
    trainer.train()

    # 7. Evaluate the model performance on the STS Benchmark test dataset
    test_dataset = load_dataset("sentence-transformers/stsb", split="test")
    test_evaluator = evaluation.SparseEmbeddingSimilarityEvaluator(
        sentences1=test_dataset["sentence1"],
        sentences2=test_dataset["sentence2"],
        scores=test_dataset["score"],
        main_similarity=SimilarityFunction.COSINE,
        name="sts-test",
    )
    test_evaluator(model)

    # 8. Save the final model
    final_output_dir = f"models/{run_name}/final"
    model.save_pretrained(final_output_dir)

    # 9. (Optional) save the model to the Hugging Face Hub!
    # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
    try:
        model.push_to_hub(run_name)
    except Exception:
        logging.error(
            f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
            f"`huggingface-cli login`, followed by loading the model using `model = CrossEncoder({final_output_dir!r})` "
            f"and saving it using `model.push_to_hub('{run_name}')`."
        )


if __name__ == "__main__":
    main()
